import sls
import nls

import numpy as np


def get_optimizer(opt, params, n_batches_per_epoch=None):
    """Build one of the project's line-search optimizers.

    Args:
        opt: Either the optimizer name (str) or a dict with a ``"name"`` key
            plus optional settings read below: ``"n_batches_per_epoch"``,
            ``"c"``, ``"infer_c"`` (``sgd_armijo``/``sls``) and
            ``"gamma_incr"``, ``"gamma_decr"`` (``nls``/``aloe``).
        params: Model parameters, forwarded unchanged to the optimizer.
        n_batches_per_epoch: Number of batches per epoch (b/n). A truthy
            ``opt["n_batches_per_epoch"]`` takes precedence over this argument.

    Returns:
        An ``sls.Sls`` instance for ``"sgd_armijo"``/``"sls"``, or an
        ``nls.Nls`` instance for ``"nls"``/``"aloe"``.

    Raises:
        ValueError: If the optimizer name is unknown, or if ``infer_c`` is
            requested while ``n_batches_per_epoch`` is unknown.
    """
    if isinstance(opt, dict):
        opt_name = opt["name"]
        opt_dict = opt
    else:
        opt_name = opt
        opt_dict = {}

    # A truthy value in the config dict overrides the function argument.
    n_batches_per_epoch = opt_dict.get("n_batches_per_epoch") or n_batches_per_epoch

    if opt_name in ("sgd_armijo", "sls"):
        if opt_dict.get("infer_c"):
            # Without this check np.sqrt(None) fails with an opaque TypeError.
            if n_batches_per_epoch is None:
                raise ValueError("infer_c requires n_batches_per_epoch to be set")
            # When infer_c is set, derive c from the epoch size instead of
            # reading it from the config.
            c = 1e-3 * np.sqrt(n_batches_per_epoch)
        else:
            # Missing/falsy "c" falls back to 0.1.
            c = opt_dict.get("c") or 0.1

        optimizer = sls.Sls(
            params,
            c=c,
            n_batches_per_epoch=n_batches_per_epoch,
            line_search_fn="armijo",
        )

    elif opt_name in ("nls", "aloe"):
        # Config values, when present, override the defaults 1.25 / 0.7.
        optimizer = nls.Nls(
            params,
            n_batches_per_epoch=n_batches_per_epoch,
            gamma_decr=opt_dict.get("gamma_decr", 0.7),
            gamma_incr=opt_dict.get("gamma_incr", 1.25),
        )

    else:
        raise ValueError("opt %s does not exist..." % opt_name)

    return optimizer
